Source code for hysop.symbolic.base

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sympy as sm

from hysop.tools.numpywrappers import npw
from hysop.tools.htypes import first_not_None
from hysop.symbolic import Symbol, Dummy, subscript
from hysop.tools.sympy_utils import (
    sstr,
    sstrrepr,
    latex as _latex,
    UnevaluatedExpr,
    UnsplittedExpr,
)
from contextlib import contextmanager


class ValueHolderI:
    """
    Interface for objects that may carry a value which can be substituted
    for them inside sympy expressions.
    """

    def __new__(cls, *args, **kwds):
        # All arguments are forwarded unchanged to the next class in the MRO.
        instance = super().__new__(cls, *args, **kwds)
        return instance

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)

    def get_holded_value(self):
        """Return the held value; this base implementation returns None."""
        return None

    @classmethod
    def get_holded_values(cls, expr):
        """
        Collect the held values found in an expression tree.

        Returns a dict mapping each ValueHolderI node whose held value is
        not None to that value. Only sympy.Expr nodes are descended into
        (through their args); a ValueHolderI node is never descended into.
        """
        found = {}

        def _visit(node):
            if isinstance(node, ValueHolderI):
                held = node.get_holded_value()
                if held is not None:
                    found[node] = held
                return
            if isinstance(node, sm.Expr):
                for child in node.args:
                    _visit(child)

        _visit(expr)
        return found

    @classmethod
    def replace_holded_values(cls, expr):
        """
        Return expr with every value holder replaced by its held value.

        When the substitution raises AttributeError (for instance because
        expr has no xreplace method), expr is returned unchanged.
        """
        mapping = cls.get_holded_values(expr)
        try:
            substituted = expr.xreplace(mapping)
        except AttributeError:
            return expr
        return substituted
class ScalarDataViewHolder(ValueHolderI):
    """
    Value holder whose value is read from a referenced object, optionally
    through an index or a callable accessor.
    """

    def __new__(cls, holded_data_ref=None, holded_data_access=None, **kwds):
        # A one-element array given without an explicit accessor is read
        # through index (0,).
        is_single_element_array = (
            isinstance(holded_data_ref, npw.ndarray)
            and (holded_data_access is None)
            and (holded_data_ref.size == 1)
        )
        if is_single_element_array:
            holded_data_access = (0,)

        instance = super().__new__(cls, **kwds)
        instance._holded_value_ref = holded_data_ref
        instance._holded_data_access = holded_data_access
        return instance

    def __init__(self, holded_data_ref=None, holded_data_access=None, **kwds):
        # The data reference and accessor are consumed by __new__.
        super().__init__(**kwds)

    def get_holded_value(self):
        """
        Return the value currently held.

        None when no data reference was given; the reference itself when
        there is no accessor; accessor(reference) for a callable accessor;
        reference[accessor] otherwise.
        """
        ref = self._holded_value_ref
        if ref is None:
            return None
        access = self._holded_data_access
        if access is None:
            return ref
        if callable(access):
            return access(ref)
        return ref[access]

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        # The referenced object contributes by identity (id), not by value.
        content = super()._hashable_content()
        content += (
            id(self._holded_value_ref),
            self._holded_data_access,
        )
        return content
class ScalarBaseTag:
    """Tag for objects that can be inserted as elements of tensors."""

    def __new__(cls, idx=None, **kwds):
        # idx is stored here; remaining keywords go to the next class in the MRO.
        instance = super().__new__(cls, **kwds)
        instance._idx = idx
        return instance

    def __init__(self, idx=None, **kwds):
        super().__init__(**kwds)

    @property
    def idx(self):
        """Index given at construction (None when not provided)."""
        return self._idx

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        content = super()._hashable_content()
        content += (self._idx,)
        return content
class ScalarBase(ScalarDataViewHolder, ScalarBaseTag):
    """Base for symbolic scalars."""

    def __new__(cls, name, value=None, view=None, **kwds):
        # value / view are shorthands for holded_data_ref / holded_data_access;
        # giving both a shorthand and its long form trips the assertion.
        shorthands = (
            ("holded_data_ref", value),
            ("holded_data_access", view),
        )
        for target, given in shorthands:
            if given is None:
                continue
            assert kwds.get(target) is None
            kwds[target] = given

        instance = super().__new__(cls, name=name, **kwds)
        instance._iterable = False
        return instance

    def __init__(self, name, value=None, view=None, **kwds):
        super().__init__(name=name, **kwds)

    def vreplace(self):
        """Call ValueHolderI.replace_holded_values on self."""
        return self.replace_holded_values(self)

    def __getitem__(self, key):
        # Only key 0 is accepted; it yields the scalar itself.
        assert key == 0
        return self
class TensorBase(npw.ndarray):
    """
    Base for symbolic tensors.

    A tensor is a npw.ndarray subclass (dtype defaults to object) containing
    symbolic scalars or symbolic expressions.
    """

    __array_priority__ = 1.0

    def __new__(
        cls,
        shape,
        init=None,
        name=None,
        pretty_name=None,
        scalar_cls=None,
        scalar_kwds=None,
        make_scalar_kwds=None,
        value=None,
        set_read_only=True,
        dtype=object,
        **kwds,
    ):
        """
        Create a new TensorBase.

        When init is given the tensor is filled with it (obj[...] = init).
        Otherwise one scalar_cls instance is created per index, named after
        name and the index.

        Parameters
        ----------
        shape:
            Shape of the tensor.
        init:
            Optional initial content; when given, no scalar is created.
        name:
            Base name of the generated scalars (required when init is None).
        pretty_name:
            Base pretty name, defaults to name.
        scalar_cls:
            Class of the generated scalars, must subclass ScalarBaseTag
            (required when init is None).
        scalar_kwds:
            Keyword arguments passed to every generated scalar.
        make_scalar_kwds:
            Optional callable idx -> dict of extra per-index keyword
            arguments; its keys must not collide with scalar_kwds.
        value:
            Forwarded as value= to every generated scalar.
        set_read_only:
            Accepted for interface compatibility.
            NOTE(review): this flag is not applied to the array flags in
            this constructor - confirm whether that is intended.
        dtype:
            Element type of the underlying array.
        """
        obj = super().__new__(cls, shape=shape, dtype=dtype, **kwds)
        if init is not None:
            obj[...] = init
            return obj

        assert name is not None
        pretty_name = first_not_None(pretty_name, name)
        assert scalar_cls is not None
        assert issubclass(scalar_cls, ScalarBaseTag)
        scalar_kwds = first_not_None(scalar_kwds, {})

        # Latex indices are concatenated only while every extent is below 10,
        # otherwise they are comma separated.
        lsep = "" if npw.less(shape, 10).all() else ","
        vsep = "_"

        with obj.write_context():
            for idx in npw.ndindex(*shape):
                # Per-element names are derived from the *base* name. Rebinding
                # `name` here would make the names accumulate the indices of
                # all previously visited elements.
                sname = "{}_{}".format(name, vsep.join(str(i) for i in idx))
                pname = "{}{}".format(
                    pretty_name, "".join(subscript(i) for i in idx)
                )
                vname = sname
                lname = "{}_{{{}}}".format(name, lsep.join(str(i) for i in idx))

                if make_scalar_kwds is None:
                    skwds = scalar_kwds
                else:
                    assert callable(make_scalar_kwds)
                    idx_kwds = make_scalar_kwds(idx)
                    for k in idx_kwds:
                        msg = f"{k} was already set in scalar_kwds."
                        assert k not in scalar_kwds, msg
                    idx_kwds.update(scalar_kwds)
                    skwds = idx_kwds

                obj[idx] = scalar_cls(
                    name=sname,
                    pretty_name=pname,
                    var_name=vname,
                    latex_name=lname,
                    value=value,
                    idx=idx,
                    **skwds,
                )
        return obj

    def __init__(
        self,
        shape,
        init=None,
        name=None,
        pretty_name=None,
        scalar_cls=None,
        scalar_kwds=None,
        make_scalar_kwds=None,
        value=None,
        set_read_only=True,
        dtype=object,
        **kwds,
    ):
        # Every named argument is consumed by __new__; only the remaining
        # keywords are forwarded.
        super().__init__(**kwds)

    def latex(self, matrix="b", with_packages=False):
        """
        Return a latex representation of this tensor.

        Parameters
        ----------
        matrix:
            Prefix of the matrix environment ("b" gives bmatrix).
        with_packages:
            When True, the output starts with the amsmath \\usepackage line.

        Only tensors with at most two dimensions are supported.
        """
        assert self.ndim <= 2
        ss = ""
        if with_packages:
            ss += r"\usepackage{amsmath}"
        ss += "\n$$"
        ss += "\n" + rf"\begin{{{matrix}matrix}}"
        for i in range(self.shape[0]):
            if self.ndim == 1:
                ss += "\n    " + _latex(self[i]) + " \\\\"
            else:
                ss += "\n    " + " & ".join(_latex(val) for val in self[i]) + " \\\\"
        ss += "\n" + rf"\end{{{matrix}matrix}}"
        ss += "\n$$"
        return ss

    def sstr(self):
        """Apply sstr elementwise."""
        return self.elementwise_fn(sstr)

    def strrepr(self):
        """Apply sstrrepr elementwise."""
        return self.elementwise_fn(sstrrepr)

    def __str__(self):
        if self.ndim == 0:
            return sstr(self.tolist())
        if (self.ndim == 1) and (self.size > 1):
            # reshape as a vector
            a = self.reshape(self.shape + (1,))
        else:
            a = self
        return npw.array2string(a, formatter={"all": lambda x: str(x)}, separator=" ")

    def __repr__(self):
        return npw.array2string(self, formatter={"all": lambda x: sstrrepr(x)})

    @contextmanager
    def write_context(self):
        """
        Temporarily grant write access to self for the duration of the context.

        The previous writeable flag is restored on exit, whether or not the
        body raised. Only useful for tensors set as read-only.
        """
        # Read the flag before entering the try block so that the finally
        # clause can never refer to an unbound name.
        previous = self.flags.writeable
        self.flags.writeable = True
        try:
            yield
        finally:
            self.flags.writeable = previous

    def elementwise_fn(self, fn):
        """
        Apply function fn on each element of the tensor and return the result.

        The result is an array created with npw.empty_like(self); for a
        0-dimensional tensor, fn is applied to the single item and its
        result is returned directly.
        """
        if self.ndim:
            data = npw.empty_like(self)
            for idx in npw.ndindex(*self.shape):
                data[idx] = fn(self[idx])
        else:
            data = fn(self.tolist())
        return data

    def __hash__(self):
        """Hash this object by its id."""
        return id(self)

    def diff(self, *symbols, **assumptions):
        """Elementwise sympy.diff()."""
        return self.elementwise_fn(lambda x: sm.diff(x, *symbols, **assumptions))

    def freeze(self):
        """Apply elementwise UnevaluatedExpr on each scalar expressions."""
        return self.elementwise_fn(lambda x: UnevaluatedExpr(x))

    def no_split(self):
        """Apply elementwise UnsplittedExpr on each scalar expressions."""
        return self.elementwise_fn(lambda x: UnsplittedExpr(x))

    def simplify(self):
        """Elementwise sympy.simplify()."""
        return self.elementwise_fn(lambda x: sm.simplify(x))

    def xreplace(self, replacements):
        """
        Elementwise sympy.xreplace().

        Keys of replacements that are arrays are expanded element by
        element; an array value must then have the same shape as its key,
        while any other value is used for every element of the key.
        Pairs with a None key or a None value are ignored.

        Elements of the tensor that have no xreplace method (for instance
        plain Python numbers) are copied unchanged.
        """
        replace = {}
        for k, v in replacements.items():
            if isinstance(k, npw.ndarray):
                for idx in npw.ndindex(*k.shape):
                    kk = k[idx]
                    if isinstance(v, npw.ndarray):
                        assert k.shape == v.shape
                        vv = v[idx]
                    else:
                        vv = v
                    if (kk is not None) and (vv is not None):
                        replace[kk] = vv
            elif (k is not None) and (v is not None):
                replace[k] = v

        data = npw.empty_like(self)
        for idx in npw.ndindex(*self.shape):
            element = self[idx]
            substitute = getattr(element, "xreplace", None)
            data[idx] = element if substitute is None else substitute(replace)
        return data

    def vreplace(self):
        """Elementwise ValueHolderI.replace_holded_values on self."""
        data = npw.empty_like(self)
        for idx in npw.ndindex(*self.shape):
            data[idx] = ValueHolderI.replace_holded_values(self[idx])
        return data
class SymbolicScalar(ScalarBase, Symbol):
    """Symbolic scalar symbol: a ScalarBase combined with Symbol."""
class DummySymbolicScalar(ScalarBase, Dummy):
    """Symbolic scalar dummy symbol: a ScalarBase combined with Dummy."""
class SymbolicTensor(TensorBase):
    """Symbolic tensor symbol; elements default to SymbolicScalar."""

    def __new__(cls, name, shape, init=None, scalar_cls=None, **kwds):
        # Fall back to SymbolicScalar when no element class is given.
        element_cls = first_not_None(scalar_cls, SymbolicScalar)
        return super().__new__(
            cls,
            name=name,
            shape=shape,
            init=init,
            scalar_cls=element_cls,
            **kwds,
        )

    def __init__(self, name, shape, init=None, scalar_cls=None, **kwds):
        super().__init__(
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            **kwds,
        )
class DummySymbolicTensor(TensorBase):
    """Dummy symbolic tensor symbol; elements default to DummySymbolicScalar."""

    def __new__(cls, name, shape, init=None, scalar_cls=None, **kwds):
        # Fall back to DummySymbolicScalar when no element class is given.
        element_cls = first_not_None(scalar_cls, DummySymbolicScalar)
        return super().__new__(
            cls,
            name=name,
            shape=shape,
            init=init,
            scalar_cls=element_cls,
            **kwds,
        )

    def __init__(self, name, shape, init=None, scalar_cls=None, **kwds):
        super().__init__(
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            **kwds,
        )
def vreplace(expr):
    """
    Replace the value holders found in expr by the values they hold.

    Module-level shortcut for ValueHolderI.replace_holded_values(expr);
    the result of that call is returned.
    """
    # The result used to be computed and then dropped, so callers always
    # received None.
    return ValueHolderI.replace_holded_values(expr)
if __name__ == "__main__":
    # --- scalars: four symbols sharing the name "a", each holding a value ---
    a = SymbolicScalar("a", value=sm.Symbol("A"))
    # Dummy symbol: different symbol with the same name.
    b = DummySymbolicScalar("a", value=sm.Symbol("B"))
    # Held value is read through index 1 of the list.
    c = DummySymbolicScalar("a", value=[sm.Symbol("C0"), sm.Symbol("C1")], view=1)
    # NOTE(review): the original comment says this is the same symbol as `a`
    # (hashed by name); confirm against _hashable_content.
    d = SymbolicScalar("a", value=sm.Symbol("D"))

    total = a + b + c + d
    print(total)
    print(ValueHolderI.replace_holded_values(total))
    print()

    # --- tensors ---
    A = SymbolicTensor("A", shape=(3, 3), value=12)
    B = SymbolicTensor("B", shape=(3, 3), set_read_only=False, value=npw.eye(3, 3))
    C = DummySymbolicTensor("C", shape=(8,))
    for tensor in (A, B, C):
        print(tensor)

    # Overwrite two entries of B with plain numbers.
    B[0, 0] = 0
    B[1, 0] = -1
    print()
    for tensor in (A, B):
        print(tensor)
    print()
    for tensor in (A, B):
        print(tensor.vreplace())
    print()
    print(A * B)
    print()

    cos_of_product = A.dot(B).elementwise_fn(sm.cos)
    print(cos_of_product)
    print()
    print(cos_of_product.diff(B[1, 1]))
    print()
    print(A.latex())